# Import base libraries

import numpy as np
import cupy as cp
import os
from os.path import exists, isfile, join
from pathlib import Path
import shutil
import gc
import math


# Import general stopwords

from nltk.corpus import stopwords

# Import TensorLy
import tensorly as tl
import cudf
from cudf import Series
from cuml.feature_extraction.text import CountVectorizer
from cuml.preprocessing.text.stem import PorterStemmer
import cupyx 


#Insert Plotly
import pandas as pd
import time
import pickle

# Import utility functions from other files
from lda.tlda.tlda_wrapper import TLDA 
from lda.tlda import file_operations as fop


# Constants: output locations (all under SAVE_FOLDER) and the input
# directories that hold the per-file metadata pickles read at prediction time.
SAVE_FOLDER = "data_new3"
X_MAT_FILEPATH_PREFIX = f'{SAVE_FOLDER}/x_mat/'        # cached document-term matrices
COUNTVECTOR_FILEPATH = f'{SAVE_FOLDER}/countvec.obj'
X_WHITENED_FILEPATH = f'{SAVE_FOLDER}/x_whit.obj'
TLDA_FILEPATH = f'{SAVE_FOLDER}/tlda.obj'
TOP_WORDS_FILEPATH = f'{SAVE_FOLDER}/top_words.csv'
VOCAB_FILEPATH = f'{SAVE_FOLDER}/vocab.csv'
WEIGHTS_FILEPATH = f"{SAVE_FOLDER}/alpha_weights.txt"
X_LABELS_FILEPATH = f"{SAVE_FOLDER}/x_label/"          # per-file prediction CSVs
X_DATE_FILEPATH_PREFIX = "data/data_split2/"           # time_<stem>.pkl files
X_PART_FILEPATH_PREFIX = "data/data_split3/"           # party_<stem>.pkl files
X_LOC_FILEPATH_PREFIX = "data/data_split4/"            # location_<stem>.pkl files

# Device settings: TensorLy runs every stage on the CuPy backend.
backend = "cupy"
tl.set_backend(backend)
device = 'cuda'
porter = PorterStemmer()

# Seed the random generator of whichever array library is active so that
# repeated runs are reproducible.
seeded_backend = cp if backend == "cupy" else np
seeded_backend.random.seed(seed=1000)

# Clean up tweets

def basic_clean(df):
    """Coerce the ``tweets`` column to strings and drop repeated tweets.

    The input frame's ``tweets`` column is overwritten in place with its
    string form; the returned frame keeps only the first occurrence of each
    distinct tweet. Called with both cudf and pandas frames in this script.
    """
    text_column = df['tweets'].astype('str')
    df['tweets'] = text_column
    deduplicated = df.drop_duplicates(subset=["tweets"], keep="first")
    return deduplicated

# Multi-word phrases that are always kept in the vocabulary, on top of the
# terms CountVectorizer.fit learns from the data.
# NOTE(review): the vectorizer below uses ngram_range=(1, 2), so the 3- and
# 4-word entries ("fake news media", "count every legal vote") presumably can
# never be matched at transform time — confirm they are still wanted.
_SEED_NGRAMS = (
    "stop steal", "count votes", "mail ballots", "voter suppression",
    "illegal votes", "legal votes", "legal vote", "illegal vote",
    "trump won", "biden won", "trump win", "cases won", "fake news",
    "court case", "court cases", "case court", "cases court",
    "count every legal vote", "60 cases", "election fraud", "voter fraud",
    "dead voter", "dead voters", "united states", "american history",
    "president election", "election rigged", "fake news media",
    "state court", "supreme court", "ballot harvesting",
    "election integrity", "breaking news", "stop count", "stop trump",
)


def partial_fit(self, data):
    """Grow ``self.vocabulary_`` incrementally from one batch of documents.

    Written to be attached to cuML's ``CountVectorizer`` as a method (the
    script assigns ``CountVectorizer.partial_fit = partial_fit``).

    On the first call the accumulated vocabulary starts from
    ``_SEED_NGRAMS``; on later calls it starts from the current
    ``self.vocabulary_``. The vectorizer is then refit on *data*, the terms
    learned from this batch are appended, and duplicates are removed with
    ``Series.unique()``.

    Parameters
    ----------
    self : the CountVectorizer instance being grown.
    data : batch of documents, passed unchanged to ``self.fit``.

    Because ``self.fit`` runs on each batch separately, the vectorizer's
    ``min_df`` / ``max_df`` thresholds are applied per batch, not over the
    whole corpus.
    """
    if hasattr(self, 'vocabulary_'):
        vocab = self.vocabulary_  # series
    else:
        # One Series built from the full list instead of one concat per
        # phrase onto an untyped empty Series.
        vocab = Series(list(_SEED_NGRAMS))

    self.fit(data)  # replaces self.vocabulary_ with this batch's terms
    vocab = cudf.concat([vocab, self.vocabulary_])

    self.vocabulary_ = vocab.unique()

def tune_filesplit_size_on_IPCA_batch_size(IPCA_batchsize):
    """Placeholder for deriving a file-split size from the IPCA batch size.

    Not implemented: the argument is ignored and ``None`` is returned.
    """
    # TODO: compute a split-file size from IPCA_batchsize.
    pass


# Stop words: NLTK's English list plus project-specific additions.
# NOTE: keep a comma after every entry — adjacent string literals are
# concatenated silently by Python.
# NOTE(review): entries containing spaces ("0 ", " 0", " 0 ", "000 000") and
# the mis-encoded "timothÃ©e" are unlikely to match single tokens — confirm.
added_words = [
    "mail", "illegal", "thread", "say", "will", "has", "by", "for", "hi", "hey", "hah",
    "thank", "metoo", "watch", "sexual", "doe", "biden",
    "said", "talk", "congrats", "congratulations", "are", "as", "i", "time", "abus", "year",
    "mani", "trump", "0 ", "000", "101", "gop", "joe",
    "me", "my", "myself", "we", "our", "ours", "ourselves", "use", "look", "movement",
    "assault", "100", "united", "states", "win",
    "you", "your", "yours", "he", "her", "him", "she", "hers", "that", "harass", "whi",
    "feel", "say", "gt", "ballots", "media", "claims",
    "be", "with", "their", "they're", "is", "was", "been", "not", "they", "womensmarch",
    "way", "thi", "suppression", "evidence", "big",
    "it", "have", "one", "think", "thing", "bring", "put", "well", "take", "exactli",
    "tell", "suprresion", "massive", "count",
    "good", "day", "work", "latest", "today", "becaus", "peopl", "via", "see", "timesup",
    "old", "ani", "realdonaldtrump", "ballot", "country",
    "call", "wouldnt", "wow", "learned", "hi", "things", "thing", "can't", "can", "right",
    "got", "show", "happened", "history", "lee",
    "cant", "will", "go", "going", "let", "would", "could", "him", "his", "think", "thi",
    "ha", "onli", "back", "president", "american",
    "lets", "let's", "say", "says", "know", "talk", "talked", "talks", "dont", "think",
    "watch", "right", " 0", "tll", "76qepeuoyr", "percent", "79", "888", "800", "75", "630",
    "said", "something", "this", "was", "has", "had", "abc", "rt", "ha", "haha", "hat",
    "even", "happen", " 0 ", "timothÃ©e", "legally", "proof",
    "something", "wont", "people", "make", "want", "went", "goes", "people", "had", "also",
    "ye", "still", "must", "55k", "0kolbjasl9", "11", "12", "saw", "134", "230", "starts",
    "person", "like", "come", "from", "yet", "able", "wa", "yah", "yeh", "yeah", "onli",
    "ask", "give", "read", "5k", "top", "60", "set",
    "need", "us", "men", "women", "get", "woman", "man", "amp", "amp&", "yr", "yrs",
    "voter", "election", "states",
    "https", "co", "http", "votes", "voters", "vote", "2020", "voters", "0", "00", "80",
    "000 000", "state", "soon", "america", "american",
    "georgia", "michigan",
]

# Plain list concatenation (no NumPy round trip); dict.fromkeys drops repeated
# entries while keeping first-seen order.
stop_words = list(dict.fromkeys(stopwords.words('english') + added_words))

# Attach the incremental-fit function defined above to cuML's vectorizer.
CountVectorizer.partial_fit = partial_fit

# Integer max_df / min_df are absolute document counts in the sklearn-style
# API. partial_fit calls fit() once per input file, so both thresholds are
# applied per file rather than over the whole corpus.
countvec = CountVectorizer(stop_words=stop_words,
                           lowercase=True,
                           ngram_range=(1, 2),
                           max_df=15000,
                           min_df=290)


# Directory holding the pickled tweet files (already split into chunks).
inDir = "data/data_split"

# Learning parameters
num_tops = 30             # number of topics
alpha_0 = 0.1
batch_size_pca = 70_000
batch_size_grad = 30_000
n_iter_train = 500
n_iter_test = 40
learning_rate = 7e-4
theta_param = 10.005
ortho_loss_param = 15_000
# NOTE(review): `smoothing` is not passed anywhere; the TLDA constructor
# further down receives a literal smoothing=1e-5. Confirm the intended value.
smoothing = 1e-7

# Program controls: 1 runs a stage. For several stages, 0 reloads the output
# an earlier run saved to disk instead.
split_files = 0  # DO NOT SET TO 1: we are giving split data!
vocab_build = 1
transform_mean = 1
save_files = 1
pca_run = 1
whiten = 1
stgd = 1
predict = 1
coherence = 0

# Other globals
num_data_rows = 0  # incremented per file during the vocabulary build

# Start
print("\n\nSTART...")

# Input file list; replaced below if the inputs are re-split.
dl = fop.get_files_in_dir(inDir)

# Optionally split the input datafiles into smaller files and use those.
if split_files == 1:
    print("Splitting files")
    split_out_dir = os.path.join("EFBatched_clean", "split_files")
    inDir = fop.split_files(inDir, split_out_dir)
    dl = fop.get_files_in_dir(inDir)
    print("Done. Split files located at: {}.\n".format(inDir))
    print("Split files and their filesizes: ")
    fop.print_filesizes(inDir)

# Build the vocabulary: stream every input file through the patched
# CountVectorizer.partial_fit, then save the vectorizer and its vocabulary.
if vocab_build == 1:
    # The save folder may not exist on a fresh run.
    os.makedirs(SAVE_FOLDER, exist_ok=True)
    for f in dl:
        print("Beginning vocabulary build: " + f)
        path_in = os.path.join(inDir, f)
        df = cudf.DataFrame()
        with open(path_in, 'rb') as fi:
            df['tweets'] = pickle.load(fi)

        # basic preprocessing
        # NOTE(review): unlike the transform/predict stages, no `len > 10`
        # filter is applied here, so short tweets still shape the vocabulary.
        df = basic_clean(df)
        countvec.partial_fit(df['tweets'])
        print("End " + f)

        # count rows of data
        num_data_rows += len(df.index)

    with open(COUNTVECTOR_FILEPATH, 'wb') as fo:
        pickle.dump(countvec, fo)

    df_voc = cudf.DataFrame({'words': countvec.vocabulary_})
    df_voc.to_csv(VOCAB_FILEPATH)

# Reuse the vectorizer saved by an earlier run when this run skips the build.
if vocab_build == 0:
    with open(COUNTVECTOR_FILEPATH, 'rb') as fi:
        countvec = pickle.load(fi)

# NOTE(review): smoothing is passed as the literal 1e-5 here, while the
# "Learning parameters" section defines smoothing = 1e-7 that is never used.
# Confirm which value is intended before wiring the constant through.
tlda = TLDA(num_tops, alpha_0, n_iter_train, n_iter_test, learning_rate,
            pca_batch_size=batch_size_pca,
            third_order_cumulant_batch=batch_size_grad,
            gamma_shape=1.0, smoothing=1e-5, theta=theta_param,
            ortho_loss_criterion=ortho_loss_param, random_seed=1000)

# Vectorize every file with the fitted vocabulary, feed it to TLDA's
# first-order (mean) accumulator, and optionally cache the dense
# document-term matrix so later stages do not re-tokenize.
if transform_mean == 1:
    vocab = len(countvec.vocabulary_)
    print("right after countvec partial fit vocab\n\n\n: ", vocab)
    if save_files == 1:
        # The cache directory may not exist on a fresh run.
        os.makedirs(X_MAT_FILEPATH_PREFIX, exist_ok=True)

    # Pre-bind so the cleanup below works even when `dl` is empty.
    X_batch = df = mask = None
    for i, f in enumerate(dl, start=1):
        if i % 100 == 0:
            print("Beginning transform/mean: " + f)
        path_in = os.path.join(inDir, f)
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()

        # read in dataframe, drop very short tweets, move it to the GPU
        df = pd.DataFrame()
        with open(path_in, 'rb') as fi:
            df['tweets'] = pickle.load(fi)
        mask = df['tweets'].str.len() > 10
        df = cudf.from_pandas(df.loc[mask])
        # basic preprocessing
        df = basic_clean(df)

        X_batch = tl.tensor(countvec.transform(df['tweets']).toarray())
        tlda._partial_fit_first_order(X_batch)

        if save_files == 1:
            with open(X_MAT_FILEPATH_PREFIX + Path(f).stem + '.obj', 'wb') as fo:
                pickle.dump(X_batch, fo)
        print("End " + f)

    print("Total length of dataset: {} rows".format(str(tlda.n_documents)))

    with open(COUNTVECTOR_FILEPATH, 'wb') as fo:
        pickle.dump(countvec, fo)
    with open(TLDA_FILEPATH, 'wb') as fo:
        pickle.dump(tlda, fo)
    del X_batch, df, mask
    gc.collect()

# When the transform/mean stage is skipped, reload the TLDA model it saved
# (the same pattern the pca_run / whiten / stgd stages use). The vectorizer
# was already reloaded above whenever vocab_build == 0.
if transform_mean == 0:
    with open(TLDA_FILEPATH, 'rb') as fi:
        tlda = pickle.load(fi)
    vocab = len(countvec.vocabulary_)

gc.collect()


# Run first eigenvalue decomposition: feed the cached document-term matrices
# to TLDA's second-order (PCA) update in stacked batches of at least
# `min_rows_per_update` rows.

if pca_run == 1:
    t1 = time.time()
    min_rows_per_update = 55000
    pending = []      # per-file matrices loaded but not yet fitted
    pending_rows = 0  # total number of rows held in `pending`
    for f in dl:
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()

        print("Beginning PCA: " + f)
        with open(X_MAT_FILEPATH_PREFIX + Path(f).stem + '.obj', 'rb') as fi:
            chunk = pickle.load(fi)
        pending.append(chunk)
        pending_rows += chunk.shape[0]
        del chunk

        if pending_rows >= min_rows_per_update:
            # Concatenate once per update instead of re-copying a growing
            # array for every file.
            X_batch = cp.concatenate(pending, axis=0)
            pending = []
            pending_rows = 0
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()

            print("X batch shape: ", X_batch.shape)
            tlda._partial_fit_second_order(X_batch)
            del X_batch
            gc.collect()

    # Rows remaining after the last full batch are fitted too rather than
    # dropped. NOTE(review): the `>= num_tops` guard assumes the incremental
    # PCA update needs at least as many rows as components — confirm against
    # TLDA._partial_fit_second_order.
    if pending_rows >= num_tops:
        X_batch = cp.concatenate(pending, axis=0)
        print("X batch shape: ", X_batch.shape)
        tlda._partial_fit_second_order(X_batch)
        del X_batch
    elif pending_rows > 0:
        print("Skipping {} trailing rows: too few for a PCA update".format(pending_rows))
    del pending
    gc.collect()

    t2 = time.time()
    print("PCA and Centering Time: " + str(t2 - t1))
    with open(TLDA_FILEPATH, 'wb') as fo:
        pickle.dump(tlda, fo)

gc.collect()
if pca_run == 0:
    with open(TLDA_FILEPATH, 'rb') as fi:
        tlda = pickle.load(fi)

gc.collect()

# Whiten every cached batch: center by the first-order mean, project with the
# fitted second-order transform, and save the stacked result.
if whiten == 1:
    t1 = time.time()
    x_whits = []
    for f in dl:
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
        print("Beginning TLDA: " + f)
        with open(X_MAT_FILEPATH_PREFIX + Path(f).stem + '.obj', 'rb') as fi:
            X_batch = cp.ndarray.get(pickle.load(fi))

        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()

        x_whits.append(tlda.second_order.transform(tl.tensor(X_batch) - tlda.mean))
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()

    x_whit = tl.concatenate(x_whits, axis=0)
    # The per-file pieces are no longer needed; drop them so the whitened data
    # is not held twice during the later stages.
    del x_whits
    print(x_whit.shape)
    with open(X_WHITENED_FILEPATH, 'wb') as fo:
        pickle.dump(x_whit, fo)
    t2 = time.time()

    print("Whiten time: " + str(t2 - t1))

if whiten == 0:
    with open(X_WHITENED_FILEPATH, 'rb') as fi:
        x_whit = pickle.load(fi)

# Run stochastic gradient descent for 3rd order moment, then export the
# vocabulary and the top words of every topic.

if stgd == 1:
    t1 = time.time()
    tlda.third_order.fit(x_whit, verbose=True)
    t2 = time.time()
    tlda_time = str(t2 - t1)
    print("TLDA Time: " + tlda_time)

    with open(TLDA_FILEPATH, 'wb') as fo:
        pickle.dump(tlda, fo)

    n_top_words = 20
    df_voc = cudf.DataFrame({'words': countvec.vocabulary_})
    df_voc.to_csv(VOCAB_FILEPATH)

    # One column per topic: the n_top_words vocabulary entries with the
    # largest unwhitened factor weights, highest first.
    top_words_columns = {}
    for k in range(num_tops):
        t_n_indices = tlda.unwhitened_factors[:, k].argsort()[:-n_top_words - 1:-1]
        top_words_LDA = countvec.vocabulary_[t_n_indices]
        top_words_columns['words_' + str(k)] = top_words_LDA.reset_index(drop=True)
    top_words_df = cudf.DataFrame(top_words_columns)
    top_words_df.to_csv(TOP_WORDS_FILEPATH)
    del df_voc, countvec, top_words_LDA


if stgd == 0:
    with open(TLDA_FILEPATH, 'rb') as fi:
        tlda = pickle.load(fi)
    tlda.third_order.n_iter_test = n_iter_test


# Label every cached batch with TLDA's topic predictions and save them
# alongside the tweet text and its date / party / location metadata.
if predict == 1:
    print("begin predict")
    # The labels directory may not exist on a fresh run.
    os.makedirs(X_LABELS_FILEPATH, exist_ok=True)
    for f in dl:
        print(f)
        stem = Path(f).stem
        with open(X_MAT_FILEPATH_PREFIX + stem + '.obj', 'rb') as fi:
            X_batch = tl.tensor(cp.ndarray.get(pickle.load(fi)))

        # Rebuild the per-file frame: tweets plus metadata columns read from
        # parallel pickles keyed by the same file stem.
        df = pd.DataFrame()
        path_in = os.path.join(inDir, f)
        with open(path_in, 'rb') as fi:
            df['tweets'] = pickle.load(fi)
        with open(X_DATE_FILEPATH_PREFIX + "time_" + stem + '.pkl', 'rb') as fi:
            df_date = pd.DataFrame(pickle.load(fi))
        df["date"] = df_date.iloc[:, 0]
        with open(X_PART_FILEPATH_PREFIX + "party_" + stem + '.pkl', 'rb') as fi:
            df_party = pd.DataFrame(pickle.load(fi))
        df["party"] = df_party.iloc[:, 0]
        with open(X_LOC_FILEPATH_PREFIX + "location_" + stem + '.pkl', 'rb') as fi:
            df_loc = pd.DataFrame(pickle.load(fi))
        df["location"] = df_loc.iloc[:, 0]  # from the location pickle
        print(df.shape)

        # Same row filtering as the transform stage, so rows line up with the
        # cached X_batch. NOTE(review): the cached matrix was deduplicated with
        # cudf while this path uses pandas — confirm both keep the same order.
        mask = df['tweets'].str.len() > 10
        df = df.loc[mask].copy()  # copy: basic_clean assigns a column
        print(df.head())
        df = basic_clean(df)
        # Filtering leaves gaps in the index; reset it so the column copies
        # below align by position with the rows of X_batch / df_final.
        df = df.reset_index(drop=True)

        print(df.shape)
        X_labs = tlda.transform(X_batch, predict=True)
        df_final = pd.DataFrame(X_labs.get())
        for column in ("tweets", "date", "party", "location"):
            df_final[column] = df[column]

        print(df_final.head())
        df_final.to_csv(X_LABELS_FILEPATH + stem + ".csv")

    with open(WEIGHTS_FILEPATH, 'w') as outFile:
        print(tlda.weights_, file=outFile)
        print(np.argsort(cp.asnumpy(tlda.weights_))[::-1], file=outFile)


# Compute the `u_mass` topic-coherence score from the cached document-term
# matrices and the top words of each topic.
if coherence == 1:
    # `M1` (the denominator) and `epsilon` (the additive term) are not defined
    # anywhere in this script; fail before loading every file, not after.
    missing_names = [name for name in ("M1", "epsilon") if name not in globals()]
    if missing_names:
        raise NameError(
            "coherence == 1 requires {} to be defined before this stage".format(
                ", ".join(missing_names)))
    n_top_words = 20  # same top-N as the stgd stage; defined here for stgd == 0

    # Stack every file's matrix into one sparse matrix.
    X = None
    for f in dl:
        print(f)
        with open(X_MAT_FILEPATH_PREFIX + Path(f).stem + '.obj', 'rb') as fi:
            X_batch = cupyx.scipy.sparse.csr_matrix(pickle.load(fi))
        if X is None:
            X = X_batch
        else:
            X = cupyx.scipy.sparse.vstack([X, X_batch])
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
    n = X.shape[0]
    tcm = X.T.dot(X)
    print(tcm.shape)
    numerator = cupyx.scipy.sparse.triu(tcm, k=1)
    denominator = M1
    print(denominator.shape)
    score = cp.log(((numerator.toarray() / n) + epsilon) / denominator)
    topic_coh = []
    for k in range(0, num_tops):
        t_n_indices = tlda.unwhitened_factors[:, k].argsort()[:-n_top_words - 1:-1]
        score_tmp = score[cp.ix_(t_n_indices, t_n_indices)]
        topic_coh.append(score_tmp.mean())

    # Average over all topics (not over the last loop index).
    u_mass = sum(topic_coh) / len(topic_coh)
    print(u_mass)




